feat(pt): add hard-coded aparam output gate for fitting nets#5495
feat(pt): add hard-coded aparam output gate for fitting nets#5495Jingbei-Bai wants to merge 8 commits into
Conversation
Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds an optional “aparam output gate” to fitting networks (PyTorch + array-API/dpmodel) and exposes it through configuration/argcheck, with tests and an example config update.
Changes:
- Introduce
use_aparam_output_gate,aparam_gate_norm, andaparam_gate_clamparguments, including validation + serialization. - Apply the gate in PT fitting forward paths and dpmodel array-API fitting path.
- Add a dedicated PT unit test and update an example training input JSON.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/model/test_aparam_output_gate.py | Adds unit tests validating gating behavior + serialize roundtrip |
| examples/fparam/train/input_aparam.json | Demonstrates new config knobs in an example input |
| deepmd/utils/argcheck.py | Exposes new fitting arguments and documentation strings |
| deepmd/pt/model/task/sezm_ener.py | Extracts raw aparam and applies output gate in SE(Z/M) energy path |
| deepmd/pt/model/task/fitting.py | Implements gate computation/application and wires it into common forward path |
| deepmd/dpmodel/fitting/general_fitting.py | Mirrors gate logic for dpmodel/array-API fitting path |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| aparam_gate_norm=1.0, | ||
| aparam_gate_clamp=True, | ||
| ).to(device) | ||
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) |
| aparam_gate_norm=norm, | ||
| aparam_gate_clamp=False, | ||
| ).to(device) | ||
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) |
|
|
||
| fitting_gate = fitting._compute_aparam_output_gate(aparam) | ||
| expected = (a_val * a_val) / (sigma * sigma * norm) | ||
| self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype))) |
| if aparam.numel() % (nf * self.numb_aparam) != 0: | ||
| raise ValueError( | ||
| f"input aparam: cannot reshape {list(aparam.shape)} " | ||
| f"into ({nf}, nloc, {self.numb_aparam})." | ||
| ) | ||
| aparam_raw = aparam.view([nf, -1, self.numb_aparam]) |
| if aparam.numel() % (nf * self.numb_aparam) != 0: | ||
| raise ValueError( | ||
| f"input aparam: cannot reshape {list(aparam.shape)} " | ||
| f"into ({nf}, nloc, {self.numb_aparam})." | ||
| ) | ||
| aparam_raw = aparam.view([nf, -1, self.numb_aparam]) |
| self, | ||
| aparam_raw: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """Hard-coded gate g = a^2 / (sigma^2 * norm) from raw aparam.""" |
| if self.numb_aparam > 1: | ||
| gate = gate.prod(dim=-1, keepdim=True) | ||
| if self.aparam_gate_clamp: | ||
| gate = gate.clamp(0.0, 1.0) |
for more information, see https://pre-commit.ci
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds an optional multiplicative "aparam output gate" for fitting outputs and a new variational-Gaussian descriptor (se_a_vg). Implements gate compute/apply helpers in dpmodel and PT, wires gating into atomic output flow and SeZM, adds args/examples/tests, and adds VG descriptor code, exports, and tests. ChangesAparam Output Gate Feature
se_a_vg Variational-Gaussian Descriptor
Estimated code review effort 🎯 4 (Complex) | ⏱️ ~45 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@source/tests/pt/model/test_aparam_output_gate.py`:
- Line 73: The assertion compares fitting_gate (which lives on env.DEVICE) to a
CPU tensor; create the expected tensor on the same device to avoid
device-mismatch failures: when constructing torch.tensor(expected, dtype=dtype)
in the test (the line comparing fitting_gate), pass device=env.DEVICE or call
.to(env.DEVICE) so the expected tensor matches fitting_gate's device before
calling torch.allclose.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 73a16adf-a7d6-43fb-adcb-ab9b7391afb8
📒 Files selected for processing (6)
deepmd/dpmodel/fitting/general_fitting.pydeepmd/pt/model/task/fitting.pydeepmd/pt/model/task/sezm_ener.pydeepmd/utils/argcheck.pyexamples/fparam/train/input_aparam.jsonsource/tests/pt/model/test_aparam_output_gate.py
|
|
||
| fitting_gate = fitting._compute_aparam_output_gate(aparam) | ||
| expected = (a_val * a_val) / (sigma * sigma * norm) | ||
| self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype))) |
There was a problem hiding this comment.
Fix expected tensor device in gate-formula assertion.
At Line 73, torch.tensor(expected, dtype=dtype) is created on CPU, while fitting_gate is on env.DEVICE, which can fail on GPU/MPS.
Proposed fix
- self.assertTrue(torch.allclose(fitting_gate, torch.tensor(expected, dtype=dtype)))
+ self.assertTrue(
+ torch.allclose(
+ fitting_gate,
+ torch.tensor(expected, dtype=dtype, device=device),
+ )
+ )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@source/tests/pt/model/test_aparam_output_gate.py` at line 73, The assertion
compares fitting_gate (which lives on env.DEVICE) to a CPU tensor; create the
expected tensor on the same device to avoid device-mismatch failures: when
constructing torch.tensor(expected, dtype=dtype) in the test (the line comparing
fitting_gate), pass device=env.DEVICE or call .to(env.DEVICE) so the expected
tensor matches fitting_gate's device before calling torch.allclose.
| ).to(device) | ||
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) | ||
|
|
||
| descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device) |
| fitting.aparam_inv_std.copy_(torch.tensor([1.0 / sigma], dtype=dtype)) | ||
|
|
||
| descriptor = torch.randn(nf, nloc, dim_descrpt, dtype=dtype, device=device) | ||
| atype = torch.zeros(nf, nloc, dtype=torch.int64, device=device) |
Gate previously multiplied only the fitting output before apply_out_stat added per-type out_bias, so sigma=0 inference still returned non-zero energy. Apply the gate after out_bias in the atomic model forward path. Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/task/fitting.py`:
- Around line 807-813: The current validation only checks aparam_raw.shape[-1]
against self.numb_aparam but does not ensure the leading dims match outs, which
can hide shape mismatches; update the block handling aparam in the method (the
aparam_raw conversion and return of _apply_aparam_output_gate) to explicitly
attempt to reshape aparam_raw to (outs.shape[0], outs.shape[1],
self.numb_aparam) and if that reshape is impossible raise a clear ValueError
describing the expected shape (using outs.shape[0], outs.shape[1],
self.numb_aparam), otherwise use the reshaped tensor for the subsequent call to
_apply_aparam_output_gate so broadcasting/multiplication errors are avoided.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: c9aa14cb-8701-4b9f-9544-1e433e241562
📒 Files selected for processing (6)
deepmd/dpmodel/atomic_model/base_atomic_model.pydeepmd/dpmodel/fitting/general_fitting.pydeepmd/pt/model/atomic_model/base_atomic_model.pydeepmd/pt/model/task/fitting.pydeepmd/pt/model/task/sezm_ener.pysource/tests/pt/model/test_aparam_output_gate.py
💤 Files with no reviewable changes (1)
- deepmd/pt/model/task/sezm_ener.py
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/pt/model/test_aparam_output_gate.py
Introduce variational-Gaussian smooth descriptor (se_a_vg) where sigma enters the radial kernel and a fifth environment-matrix column per VGM II. Wire aparam through the atomic model, register argcheck/compression hooks, and keep fitting output gate unchanged. Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)
2285-2290:⚠️ Potential issue | 🟠 Major | ⚡ Quick winEnforce a positive lower bound for
aparam_gate_norm.Line 2285 currently accepts
0/negative values, but this factor is used in a denominator; that can cause divide-by-zero or invalid gate scaling at runtime. Add schema validation here to fail fast.Proposed fix
Argument( "aparam_gate_norm", float, optional=True, default=1.0, + extra_check=lambda x: x > 0.0, + extra_check_errmsg="must be greater than 0", doc=doc_aparam_gate_norm, ),🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/utils/argcheck.py` around lines 2285 - 2290, The schema entry for "aparam_gate_norm" currently allows zero/negative values which can cause divide-by-zero; update the argument schema in deepmd/utils/argcheck.py for the "aparam_gate_norm" field to enforce a strict positive lower bound (e.g., min > 0) or add a validator that raises an error if the provided value is <= 0 so the check fails fast; locate the schema definition containing "aparam_gate_norm", adjust its validation rules (or add a custom validator function) to reject non-positive values while keeping the default=1.0.
🧹 Nitpick comments (2)
deepmd/pt/model/atomic_model/dp_atomic_model.py (1)
290-291: ⚡ Quick winAvoid per-call
inspect.signature(...)in the forward hot path.Line 290 recomputes the signature every batch. Cache whether
aparamis supported once in__init__and reuse it inforward_atomic.Proposed refactor
class DPAtomicModel(BaseAtomicModel): @@ def __init__( @@ self.eval_descriptor_list = [] self.eval_fitting_last_layer_list = [] + self._descriptor_accepts_aparam = ( + "aparam" in inspect.signature(self.descriptor.forward).parameters + ) @@ - if "aparam" in inspect.signature(self.descriptor.forward).parameters: + if self._descriptor_accepts_aparam: descriptor_kwargs["aparam"] = aparam🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/atomic_model/dp_atomic_model.py` around lines 290 - 291, The code currently calls inspect.signature(self.descriptor.forward) inside forward_atomic for every batch; instead determine once in __init__ whether the descriptor.forward accepts "aparam" (e.g. set self._descriptor_supports_aparam = "aparam" in inspect.signature(self.descriptor.forward).parameters) and then in forward_atomic replace the per-call inspect.signature check with that cached boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any initialization path that sets self.descriptor to ensure the cached flag is computed after descriptor is assigned.deepmd/pt/model/descriptor/se_a.py (1)
325-365: ⚡ Quick winDocument the
aparamparameter in the docstring.The
forwardmethod now accepts anaparamparameter (line 325) but it is not documented in the docstring (lines 333-364). Even thoughDescrptSeAdoes not use this parameter (it's immediately deleted on line 365), documenting it helps maintain API clarity and aids developers who may reference this signature.📝 Suggested docstring addition
comm_dict The data needed for communication for parallel inference. + aparam + Atomic parameters. Not used by this descriptor; accepted for + interface compatibility with aparam-aware descriptors. Returns🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt/model/descriptor/se_a.py` around lines 325 - 365, The forward method signature includes an unused parameter aparam that is deleted immediately; update the forward docstring to include a brief description of aparam (type torch.Tensor | None, optional), its purpose or note that it is accepted for API compatibility and currently unused, and mention that it will be ignored (or deleted) within DescrptSeA.forward to clarify behavior for callers and maintainers; locate the method by the forward function in this module and add the aparam entry to the existing Parameters section.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/descriptor/env_mat_vg.py`:
- Line 48: sigma_loc is being cast to nlist.dtype which is integer and truncates
fractional aparam values; remove the cast to nlist.dtype and instead preserve
aparam's floating dtype (or if device alignment is needed, cast only the device:
.to(device=nlist.device) or .to(dtype=aparam.dtype, device=nlist.device)).
Update the assignment for sigma_loc (from aparam[:, :nloc, 0]) to stop
converting to nlist.dtype so sigma retains its fractional values.
In `@deepmd/pt/model/descriptor/se_a_vg.py`:
- Around line 603-609: The serialization currently hardcodes "trainable": True;
change it to emit the actual flag from the descriptor object (e.g., use
obj.trainable or getattr(obj, "trainable", True)) in the serialize() output so a
descriptor created with trainable=False round-trips correctly; update the entry
replacing the literal True with the object's trainable attribute in the block
that builds the dict (the one referencing "`@variables`", "type_map", and
"type_one_side").
- Around line 282-305: The aparam tensor (aparam_t) is only sized for local
atoms but prod_env_mat_vg expects per-atom parameters for the extended atom set
used by nlist; expand aparam to the extended-atom domain before calling
prod_env_mat_vg by mapping local aparam values to extended indices (use the same
mapping that produced extended_atype/extended_coord from
extend_input_and_build_neighbor_list) to create extended_aparam (matching
extended_atype.shape/length) and pass that instead of aparam_t; apply the same
fix at the second call site around prod_env_mat_vg later in the file (the block
at ~378-401) so halo/periodic neighbors use the correct VG parameters.
- Around line 134-140: Normalize sel to a list before computing its length: move
or duplicate the normalization (self.sel = sel if isinstance(sel, list) else
[sel]) so that self.ntypes is set from len(self.sel) (not len(sel)), and then
compute self.sec using np.cumsum(self.sel); update the assignments around
self.ntypes, self.sel, and self.sec in __init__ (or set self.ntypes =
len(self.sel) immediately after the existing self.sel assignment) to avoid
calling len() on an int.
---
Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 2285-2290: The schema entry for "aparam_gate_norm" currently
allows zero/negative values which can cause divide-by-zero; update the argument
schema in deepmd/utils/argcheck.py for the "aparam_gate_norm" field to enforce a
strict positive lower bound (e.g., min > 0) or add a validator that raises an
error if the provided value is <= 0 so the check fails fast; locate the schema
definition containing "aparam_gate_norm", adjust its validation rules (or add a
custom validator function) to reject non-positive values while keeping the
default=1.0.
---
Nitpick comments:
In `@deepmd/pt/model/atomic_model/dp_atomic_model.py`:
- Around line 290-291: The code currently calls
inspect.signature(self.descriptor.forward) inside forward_atomic for every
batch; instead determine once in __init__ whether the descriptor.forward accepts
"aparam" (e.g. set self._descriptor_supports_aparam = "aparam" in
inspect.signature(self.descriptor.forward).parameters) and then in
forward_atomic replace the per-call inspect.signature check with that cached
boolean to conditionally add descriptor_kwargs["aparam"] = aparam; update any
initialization path that sets self.descriptor to ensure the cached flag is
computed after descriptor is assigned.
In `@deepmd/pt/model/descriptor/se_a.py`:
- Around line 325-365: The forward method signature includes an unused parameter
aparam that is deleted immediately; update the forward docstring to include a
brief description of aparam (type torch.Tensor | None, optional), its purpose or
note that it is accepted for API compatibility and currently unused, and mention
that it will be ignored (or deleted) within DescrptSeA.forward to clarify
behavior for callers and maintainers; locate the method by the forward function
in this module and add the aparam entry to the existing Parameters section.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: e1861c11-00bd-407a-bffa-34f107386e4f
📒 Files selected for processing (8)
deepmd/pt/model/atomic_model/dp_atomic_model.pydeepmd/pt/model/descriptor/__init__.pydeepmd/pt/model/descriptor/env_mat_vg.pydeepmd/pt/model/descriptor/se_a.pydeepmd/pt/model/descriptor/se_a_vg.pydeepmd/pt/model/model/__init__.pydeepmd/utils/argcheck.pysource/tests/pt/model/test_se_a_vg.py
| self.ntypes = len(sel) | ||
| self.type_one_side = type_one_side | ||
| self.seed = seed | ||
| self.reinit_exclude(exclude_types) | ||
|
|
||
| self.sel = sel if isinstance(sel, list) else [sel] | ||
| self.sec = [0, *np.cumsum(self.sel).tolist()] |
There was a problem hiding this comment.
Normalize scalar sel before calling len().
The public API accepts sel: int | list[int], but line 134 evaluates len(sel) before the scalar case is converted on line 139. sel=64 will fail in __init__.
Suggested fix
- self.ntypes = len(sel)
+ self.sel = sel if isinstance(sel, list) else [sel]
+ self.ntypes = len(self.sel)
self.type_one_side = type_one_side
self.seed = seed
self.reinit_exclude(exclude_types)
-
- self.sel = sel if isinstance(sel, list) else [sel]
self.sec = [0, *np.cumsum(self.sel).tolist()]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| self.ntypes = len(sel) | |
| self.type_one_side = type_one_side | |
| self.seed = seed | |
| self.reinit_exclude(exclude_types) | |
| self.sel = sel if isinstance(sel, list) else [sel] | |
| self.sec = [0, *np.cumsum(self.sel).tolist()] | |
| self.sel = sel if isinstance(sel, list) else [sel] | |
| self.ntypes = len(self.sel) | |
| self.type_one_side = type_one_side | |
| self.seed = seed | |
| self.reinit_exclude(exclude_types) | |
| self.sec = [0, *np.cumsum(self.sel).tolist()] |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 134 - 140, Normalize sel
to a list before computing its length: move or duplicate the normalization
(self.sel = sel if isinstance(sel, list) else [sel]) so that self.ntypes is set
from len(self.sel) (not len(sel)), and then compute self.sec using
np.cumsum(self.sel); update the assignments around self.ntypes, self.sel, and
self.sec in __init__ (or set self.ntypes = len(self.sel) immediately after the
existing self.sel assignment) to avoid calling len() on an int.
| aparam_t = torch.tensor( | ||
| aparam_np[ff], dtype=self.prec, device=env.DEVICE | ||
| ).reshape(1, nloc, 1) | ||
| extended_coord, extended_atype, _, nlist = ( | ||
| extend_input_and_build_neighbor_list( | ||
| coord_t, | ||
| atype_t, | ||
| self.rcut, | ||
| self.sel, | ||
| mixed_types=False, | ||
| box=box_t, | ||
| ) | ||
| ) | ||
| env_mat, _, _ = prod_env_mat_vg( | ||
| extended_coord, | ||
| nlist, | ||
| extended_atype[:, :nloc], | ||
| aparam_t, | ||
| self.mean, | ||
| torch.ones_like(self.stddev), | ||
| self.rcut, | ||
| self.rcut_smth, | ||
| protection=self.env_protection, | ||
| ) |
There was a problem hiding this comment.
Expand aparam to the extended-atom domain before building the VG env-mat.
Both call sites hand prod_env_mat_vg() an aparam tensor sized only for local atoms, while nlist indexes extended_coord. That makes every periodic/halo neighbor fall back to the zero-filled sigma path, so training stats and runtime descriptors diverge from the intended VG kernel whenever an extended image is selected. Please build an extended_aparam with the same mapping used for the extended coordinates before calling the env-mat helper.
Also applies to: 378-401
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 282 - 305, The aparam
tensor (aparam_t) is only sized for local atoms but prod_env_mat_vg expects
per-atom parameters for the extended atom set used by nlist; expand aparam to
the extended-atom domain before calling prod_env_mat_vg by mapping local aparam
values to extended indices (use the same mapping that produced
extended_atype/extended_coord from extend_input_and_build_neighbor_list) to
create extended_aparam (matching extended_atype.shape/length) and pass that
instead of aparam_t; apply the same fix at the second call site around
prod_env_mat_vg later in the file (the block at ~378-401) so halo/periodic
neighbors use the correct VG parameters.
| "@variables": { | ||
| "davg": obj["davg"].detach().cpu().numpy(), | ||
| "dstd": obj["dstd"].detach().cpu().numpy(), | ||
| }, | ||
| "type_map": self.type_map, | ||
| "trainable": True, | ||
| "type_one_side": obj.type_one_side, |
There was a problem hiding this comment.
Serialize the real trainable flag.
serialize() always emits "trainable": True, so a descriptor created with trainable=False will round-trip as trainable after save/load or compression setup.
Suggested fix
- "trainable": True,
+ "trainable": obj.trainable,🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/se_a_vg.py` around lines 603 - 609, The
serialization currently hardcodes "trainable": True; change it to emit the
actual flag from the descriptor object (e.g., use obj.trainable or getattr(obj,
"trainable", True)) in the serialize() output so a descriptor created with
trainable=False round-trips correctly; update the entry replacing the literal
True with the object's trainable attribute in the block that builds the dict
(the one referencing "`@variables`", "type_map", and "type_one_side").
njzjz-bot
left a comment
There was a problem hiding this comment.
I found two blocking issues in the new VG env-mat path. Both are inline below.
— OpenClaw 2026.5.28 (model: custom-chat-jinzhezeng-group/gpt-5.5)
| ) -> torch.Tensor: | ||
| """Map per-atom aparam to neighbor-list sigma values.""" | ||
| nf, _, nnei = nlist.shape | ||
| sigma_loc = aparam[:, :nloc, 0].to(dtype=nlist.dtype) |
There was a problem hiding this comment.
This casts sigma values to nlist.dtype (int64), so fractional aparam/sigma values such as 0.5 become 0 before building sigma_ij. That makes the VG descriptor ignore non-integer sigma values and also explains why the current tests only compare 0 vs 1. Please keep this in the floating dtype/device of aparam instead of the neighbor-list dtype.
| device=sigma_loc.device, | ||
| ) | ||
| sigma_ext[:, :nloc] = sigma_loc | ||
| index = nlist.reshape(nf, -1) |
There was a problem hiding this comment.
nlist can contain -1 padding entries, but this gathers with the raw neighbor list. torch.gather does not accept negative indices, so any padded neighbor list will raise at runtime. _make_env_mat_vg already builds nlist_safe; this path needs the same masking/safe-index handling before gathering sigma values.
TorchScript-safe descriptor aparam forwarding, gate reshape parity with dpmodel, and env_mat_vg JIT/stat fixes so frozen.pth matches checkpoint eval. Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt/model/descriptor/env_mat_vg.py`:
- Around line 66-76: The _gather_neighbor_sigma routine currently leaves
ghost/periodic-image slots in sigma_ext as zeros which makes torch.gather return
zero sigmas for ghost neighbors; change the logic so ghost indices (positions >=
nloc) are populated by the source-atom sigma via the existing mapping (or remap
nlist to local-source indices before gathering) instead of zeros: use mapping to
copy sigma_loc into sigma_ext at ghost positions (or rewrite index = nlist_safe
to map ghost entries to the corresponding local atom indices) so sigma_nei
reflects the mapped source sigma; keep the original device/dtype handling and
return semantics unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 42c29a36-04a3-4ebd-a59a-996c070bf6c1
📒 Files selected for processing (6)
deepmd/pt/model/atomic_model/dp_atomic_model.pydeepmd/pt/model/descriptor/env_mat_vg.pydeepmd/pt/model/descriptor/se_a_vg.pydeepmd/pt/model/task/fitting.pysource/tests/pt/model/test_aparam_output_gate.pysource/tests/pt/model/test_se_a_vg.py
🚧 Files skipped from review as they are similar to previous changes (3)
- source/tests/pt/model/test_aparam_output_gate.py
- deepmd/pt/model/task/fitting.py
- deepmd/pt/model/descriptor/se_a_vg.py
| sigma_ext = torch.zeros( | ||
| (nf, nall + 1), | ||
| dtype=sigma_loc.dtype, | ||
| device=sigma_loc.device, | ||
| ) | ||
| sigma_ext[:, :nloc] = sigma_loc | ||
| nlist_safe = torch.where(mask, nlist, nall).to(torch.int64) | ||
| index = nlist_safe.reshape(nf, -1) | ||
| sigma_nei = torch.gather(sigma_ext, 1, index) | ||
| sigma_nei = sigma_nei.view(nf, nloc, nnei) | ||
| return torch.where(mask, sigma_nei, torch.zeros_like(sigma_nei)) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check if mapping is used elsewhere for ghost->local lookup in VG code
rg -n "mapping" deepmd/pt/model/descriptor/se_a_vg.py deepmd/pt/model/descriptor/env_mat_vg.py
# Check test fixture to see if ghost atoms are present
ast-grep --pattern 'class TestCaseSingleFrameWithNlist'Repository: deepmodeling/deepmd-kit
Length of output: 19223
🌐 Web query:
deepmd-kit variational Gaussian descriptor sigma ghost atoms
💡 Result:
In the context of DeePMD-kit, the terms variational Gaussian descriptor, sigma, and ghost atoms do not constitute a single, standard feature or configuration parameter set within the primary, documented descriptor types (such as se_e2_a, se_atten, or dpa4) [1][2][3][4]. The official documentation focuses on Deep Potential (DP) descriptors that typically rely on smooth embedding networks [1][5][6]. Regarding the specific terms you mentioned: 1. Ghost atoms: While DeePMD-kit handles variable numbers of atoms via padding, this is generally referred to in the documentation as padding with "virtual particles" [1][4] or within the context of specific masking descriptors (e.g., se_a_mask) [1][7]. 2. Gaussian: Gaussian-based basis functions are available as an alternative to Bessel-like basis functions for radial distance representation in certain advanced descriptors like DPA4/SeZM [3]. 3. Sigma: This parameter is commonly associated with Gaussian width in radial basis functions (RBFs) in various machine learning potentials, but it is not a primary configuration keyword in the standard DeePMD-kit input files for the mainstream se_e2_a/se_atten descriptors [2][6]. If you are referencing a specific research implementation, custom extension, or an experimental branch (such as those exploring Gaussian-based descriptors or Gaussian-centered local connectivity), please consult the specific source code or paper associated with that development, as these terms are not part of the standard DeePMD-kit API [8][5]. top_results: [1][2][3][6]
Citations:
- 1: https://docs.deepmodeling.com/projects/deepmd/en/stable/model/overall.html
- 2: https://docs.deepmodeling.com/projects/deepmd/en/v3.0.2/model/train-se-atten.html
- 3: https://docs.deepmodeling.com/projects/deepmd/en/latest/model/dpa4.html
- 4: https://deepmd-kit.readthedocs.io/en/latest/model/overall.html
- 5: https://docs.deepmodeling.com/projects/deepmd/en/v2.2.8/api%5Fpy/deepmd.descriptor.html
- 6: https://docs.deepmodeling.org/projects/deepmd/en/master/model/train-se-e2-a.html
- 7: https://docs.deepmodeling.com/projects/deepmd/en/v2.2.5/api_py/deepmd.descriptor.html
- 8: https://docs.deepmodeling.com/projects/deepmd/en/v3.0.0b2/_modules/deepmd/pt/model/descriptor/gaussian_lcc.html
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect env_mat_vg.py around the cited lines (56-90-ish) and the full _gather_neighbor_sigma definition.
FILE="deepmd/pt/model/descriptor/env_mat_vg.py"
echo "=== Context around lines 50-110 in $FILE ==="
sed -n '1,140p' "$FILE" | nl -ba | sed -n '45,110p'
echo
echo "=== Find _gather_neighbor_sigma definition ==="
rg -n "_gather_neighbor_sigma" "$FILE" || true
# Print a wider window around any match
rg -n "_gather_neighbor_sigma" "$FILE" -n | cut -d: -f1 | while read -r start; do
s=$((start-30)); e=$((start+80));
echo "--- $FILE:${s}-${e} ---"
sed -n "${s},${e}p" "$FILE" | nl -ba
done
echo
echo "=== Search for sigma_ext usage in $FILE ==="
rg -n "sigma_ext" "$FILE" || true
echo
echo "=== Search for nlist_safe / torch.gather patterns in $FILE ==="
rg -n "nlist_safe|torch\.gather\(sigma_ext" "$FILE" || true
echo
echo "=== Inspect call site in se_a_vg.py where env_mat_vg is invoked ==="
SE="deepmd/pt/model/descriptor/se_a_vg.py"
rg -n "prod_env_mat_vg|env_mat_vg|_gather_neighbor_sigma" "$SE" || true
# Print likely relevant sections around where prod_env_mat_vg is called.
# Use a coarse window around matches.
for m in $(rg -n "prod_env_mat_vg" "$SE" | cut -d: -f1 | head -n 5); do
s=$((m-40)); e=$((m+140));
echo "--- $SE:${s}-${e} ---"
sed -n "${s},${e}p" "$SE" | nl -ba
done
echo
echo "=== Check whether mapping is used to remap nlist / indices before gather ==="
rg -n "mapping.*nlist|nlist.*mapping|index.*mapping|reshape.*mapping" "$SE" || trueRepository: deepmodeling/deepmd-kit
Length of output: 208
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="deepmd/pt/model/descriptor/env_mat_vg.py"
SE="deepmd/pt/model/descriptor/se_a_vg.py"
# Simple line-numbered printer without nl
print_range () {
local f="$1" s="$2" e="$3"
awk -v s="$s" -v e="$e" '{ if(NR>=s && NR<=e) printf "%6d | %s\n", NR, $0 }' "$f"
}
echo "=== Basic file stats ==="
wc -l "$FILE"
wc -l "$SE"
echo
echo "=== Show env_mat_vg.py around lines 40-110 (target snippet is ~66-76) ==="
print_range "$FILE" 40 110
echo
echo "=== Locate _gather_neighbor_sigma and print its full definition area ==="
rg -n "def _gather_neighbor_sigma" "$FILE" || true
# If found, print a generous window around the first match
start=$(rg -n "def _gather_neighbor_sigma" "$FILE" | head -n1 | cut -d: -f1 || true)
if [ -n "${start:-}" ]; then
print_range "$FILE" $((start-20)) $((start+140))
fi
echo
echo "=== Search sigma_ext / nlist_safe usage ==="
rg -n "sigma_ext|nlist_safe|torch\.gather\(sigma_ext" "$FILE" || true
echo
echo "=== Inspect call site in se_a_vg.py that invokes env_mat_vg / prod_env_mat_vg ==="
rg -n "prod_env_mat_vg|env_mat_vg" "$SE" || true
prod_line=$(rg -n "prod_env_mat_vg" "$SE" | head -n1 | cut -d: -f1 || true)
if [ -n "${prod_line:-}" ]; then
print_range "$SE" $((prod_line-60)) $((prod_line+180))
fi
echo
echo "=== Check whether mapping is applied to nlist/indices before env_mat_vg gather ==="
rg -n "mapping|nlist.*mapping|index.*mapping|gather\(.*mapping|remap" "$SE" || true
rg -n "mapping" "$FILE" || true
echo
echo "=== Show deletion/discarding of mapping in se_a_vg.py (around where mapping is referenced) ==="
rg -n "del .*mapping|mapping=" "$SE" || trueRepository: deepmodeling/deepmd-kit
Length of output: 18233
Fix ghost-neighbor sigma gathering for periodic VG environments
_gather_neighbor_sigma in deepmd/pt/model/descriptor/env_mat_vg.py builds sigma_ext with sigma_ext[:, :nloc] = sigma_loc, leaving indices >= nloc (ghost/periodic-image entries that can appear in nlist) as zero; torch.gather(sigma_ext, ...) therefore returns sigma_neighbor=0 for those neighbors. Since _make_env_mat_vg computes sigma_ij = sqrt(sigma_loc^2 + sigma_neighbor^2), ghost neighbors become sigma_ij = sigma_loc and never combine with the ghost’s source-atom sigma. If ghost neighbors are intended to inherit sigma from their mapped local source, fill sigma_ext using the mapping (or remap nlist indices) instead of leaving ghost slots at zero; otherwise document that ghost-neighbor sigma is intentionally zero.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@deepmd/pt/model/descriptor/env_mat_vg.py` around lines 66 - 76, The
_gather_neighbor_sigma routine currently leaves ghost/periodic-image slots in
sigma_ext as zeros which makes torch.gather return zero sigmas for ghost
neighbors; change the logic so ghost indices (positions >= nloc) are populated
by the source-atom sigma via the existing mapping (or remap nlist to
local-source indices before gathering) instead of zeros: use mapping to copy
sigma_loc into sigma_ext at ghost positions (or rewrite index = nlist_safe to
map ghost entries to the corresponding local atom indices) so sigma_nei reflects
the mapped source sigma; keep the original device/dtype handling and return
semantics unchanged.
Summary by CodeRabbit
New Features
Bug Fixes / Behavior
Tests